# Copyright 2022 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from pathlib import Path
import platform
import subprocess

# Command-line interface: the single option overrides the flatc binary to use.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--flatc", help="path of the Flat C compiler relative to the root directory"
)
args = parser.parse_args()

# Directory holding this script, resolved so the helpers below behave the same
# regardless of the working directory the script is launched from.
script_path = Path(__file__).parent.resolve()

# The root directory is two levels above the script directory; keeping it
# absolute makes every path derived from it absolute as well.
root_path = script_path.parent.parent.absolute()

# Choose the flatc executable: the --flatc option wins when it is given,
# otherwise fall back to the default binary name for the current platform.
if args.flatc:
  flatc_exe = Path(args.flatc)
elif platform.system() == "Windows":
  flatc_exe = Path("flatc.exe")
else:
  flatc_exe = Path("flatc")

# A path that already lies under the root is first made relative to it, then
# everything is anchored at the root. Fail early when the compiler is missing.
if root_path in flatc_exe.parents:
  flatc_exe = flatc_exe.relative_to(root_path)
flatc_path = Path(root_path, flatc_exe)
assert flatc_path.exists(), "Cannot find the flatc compiler " + str(flatc_path)


def flatc(options, cwd=script_path):
  """Run the flatc compiler with the given command-line options.

  Args:
    options: list of argument strings placed after the flatc executable.
    cwd: working directory for the compiler process; defaults to the directory
      containing this script.

  Raises:
    subprocess.CalledProcessError: if flatc exits with a non-zero status.
  """
  subprocess.check_call([str(flatc_path)] + options, cwd=str(cwd))


def reflection_fbs_path():
  """Return the path of reflection/reflection.fbs below the root directory."""
  return Path(root_path, "reflection", "reflection.fbs")


def make_absolute(filename, path=script_path):
  """Join `filename` onto `path` and return the absolute result as a string.

  Args:
    filename: file name to place under `path`.
    path: base directory; defaults to the directory containing this script.

  Returns:
    The absolute path as a str.
  """
  joined = Path(path, filename)
  return str(joined.absolute())


def assert_file_exists(filename, path=script_path):
  """Assert that `filename` exists under `path` and return the joined path.

  Args:
    filename: file name (str or Path) looked up under `path`.
    path: base directory; defaults to the directory containing this script.

  Returns:
    The joined Path, so the call can be nested inside other helpers.

  Raises:
    AssertionError: if nothing exists at the joined path.
  """
  file = Path(path, filename)
  # str() keeps the message buildable when filename is a Path; plain
  # concatenation would raise TypeError and hide the real assertion failure.
  assert file.exists(), "could not find file: " + str(filename)
  return file


def assert_file_doesnt_exists(filename, path=script_path):
  """Assert that `filename` does NOT exist under `path`; return the joined path.

  Args:
    filename: file name (str or Path) looked up under `path`.
    path: base directory; defaults to the directory containing this script.

  Returns:
    The joined Path that was checked.

  Raises:
    AssertionError: if something exists at the joined path.
  """
  file = Path(path, filename)
  # str() keeps the message buildable when filename is a Path; plain
  # concatenation would raise TypeError and hide the real assertion failure.
  assert not file.exists(), "file exists but shouldn't: " + str(filename)
  return file


def get_file_contents(filename, path=script_path):
  """Read `filename` under `path` as text and return its entire contents.

  Args:
    filename: file name to read, joined onto `path`.
    path: base directory; defaults to the directory containing this script.

  Returns:
    The file contents as a str, decoded with the default text encoding.
  """
  return Path(path, filename).read_text()


def assert_file_contains(file, needles):
  """Assert that a file's text contains every expected substring.

  Args:
    file: path of the file to read (anything accepted by open()).
    needles: a single string, or an iterable of strings, each of which must
      occur in the file contents.

  Returns:
    The `file` argument unchanged, so calls can be chained.

  Raises:
    AssertionError: if a needle is missing; the message names the needle and
      the file path.
  """
  # Keep the handle under its own name: rebinding `file` made the failure
  # message and the return value refer to the closed handle, not the path.
  with open(file) as handle:
    contents = handle.read()
  for needle in [needles] if isinstance(needles, str) else needles:
    assert needle in contents, (
        "coudn't find '" + needle + "' in file: " + str(file)
    )
  return file


def assert_file_doesnt_contains(file, needles):
  """Assert that a file's text contains none of the given substrings.

  Args:
    file: path of the file to read (anything accepted by open()).
    needles: a single string, or an iterable of strings, none of which may
      occur in the file contents.

  Returns:
    The `file` argument unchanged, so calls can be chained.

  Raises:
    AssertionError: if a needle is present; the message names the needle and
      the file path.
  """
  # Keep the handle under its own name: rebinding `file` made the failure
  # message and the return value refer to the closed handle, not the path.
  with open(file) as handle:
    contents = handle.read()
  for needle in [needles] if isinstance(needles, str) else needles:
    assert needle not in contents, (
        "Found unexpected '" + needle + "' in file: " + str(file)
    )
  return file


def assert_file_and_contents(
    file, needle, doesnt_contain=None, path=script_path, unlink=True
):
  """Check that a file exists with the expected contents, then remove it.

  Args:
    file: file name looked up under `path`.
    needle: string or strings that must occur in the file.
    doesnt_contain: optional string or strings that must not occur in the
      file; the check is skipped when this is falsy.
    path: base directory; defaults to the directory containing this script.
    unlink: when truthy, delete the file after all checks have passed.

  Raises:
    AssertionError: if the file is missing or any content check fails; the
      file is not deleted in that case.
  """
  present = assert_file_exists(file, path)
  assert_file_contains(present, needle)
  if doesnt_contain:
    rechecked = assert_file_exists(file, path)
    assert_file_doesnt_contains(rechecked, doesnt_contain)
  if unlink:
    Path(path, file).unlink()


def run_all(*modules):
  """Run every public callable of each given object and tally the results.

  For each object in `modules`, every attribute listed by dir() that is
  callable and whose name does not start with "__" is invoked with the object
  itself as its only argument. A call that raises an Exception counts as a
  failure; anything else counts as a pass. Progress and a per-object summary
  are printed to stdout.

  Args:
    *modules: objects (each must have a __name__) whose callables are run.

  Returns:
    A (passing, failing) tuple with the totals across all objects.
  """
  total_passed = 0
  total_failed = 0
  for module in modules:
    # Collect the names up front so the set of callables is fixed before any
    # of them runs.
    names = []
    for name in dir(module):
      if callable(getattr(module, name)) and not name.startswith("__"):
        names.append(name)

    passed = 0
    failed = 0
    for name in names:
      try:
        print("{0}.{1}".format(module.__name__, name))
        getattr(module, name)(module)
        print(" [PASSED]")
      except Exception as error:
        print(" [FAILED]: " + str(error))
        failed += 1
      else:
        passed += 1

    summary = "{0}: {1} of {2} passsed".format(
        module.__name__, passed, passed + failed
    )
    print(summary)
    total_passed += passed
    total_failed += failed
  return total_passed, total_failed
